def train_data_generator(normalization):
    """Load every sample under ``<cwd>/training_data`` and return them shuffled.

    Samples live in numbered sub-directories ``training_data/<i>/`` for
    ``i`` in ``range(training_dataset_length + test_dataset_length)`` (both
    values come from ``Global_variable_setting``). For each sample:

    * the input frame is read from ``distorted_detector.fits`` when
      ``dataset_kind == 'normal'``, otherwise from ``perfect_detector.fits``;
    * the target is read from ``perfect_acts.fits``.

    Args:
        normalization: When truthy, each ``patch_size`` x ``patch_size`` tile
            of the frame is min-max scaled in place before use.

    Returns:
        A list of ``(frame, dmcommand)`` tuples in random order, with
        ``training_dataset_length + test_dataset_length`` entries:

        * ``frame`` is the float32 primary-HDU image with a trailing
          singleton axis appended;
        * ``dmcommand`` is the float64 primary-HDU array of
          ``perfect_acts.fits``.

        Training and test samples are mixed together in this list.
    """
    import numpy as np
    import random
    import os
    from astropy.io import fits

    from Global_variable_setting import image_size, patch_size
    from Global_variable_setting import dataset_kind
    from Global_variable_setting import training_dataset_length, test_dataset_length

    def normalize(data):
        """Min-max scale each ``patch_size`` x ``patch_size`` tile of ``data`` in place.

        Indexing uses two axes, so ``data`` is assumed to be a 2-D image of at
        least ``image_size`` x ``image_size`` (TODO confirm against the FITS
        files). Tiles whose values are all equal, including all-zero tiles,
        are set to 0. Rows/columns past the last full tile (when
        ``image_size`` is not a multiple of ``patch_size``) are left untouched.
        """
        tiles_per_side = image_size // patch_size
        for row in range(tiles_per_side):
            rows = slice(row * patch_size, (row + 1) * patch_size)
            for col in range(tiles_per_side):
                cols = slice(col * patch_size, (col + 1) * patch_size)
                tile = data[rows, cols]
                lo, hi = tile.min(), tile.max()
                # Guard on the range, not on the max: a constant non-zero tile
                # would otherwise compute 0/0 and turn into NaN.
                if hi > lo:
                    data[rows, cols] = (tile - lo) / (hi - lo)
                else:
                    data[rows, cols] = 0

    data_dir = os.path.join(os.getcwd(), 'training_data')
    # The choice of input frame does not change per sample, so decide it once.
    if dataset_kind == 'normal':
        frame_file = 'distorted_detector.fits'
    else:
        frame_file = 'perfect_detector.fits'

    data = []
    for i in range(training_dataset_length + test_dataset_length):
        sample_dir = os.path.join(data_dir, str(i))
        # Context managers close each FITS handle immediately; astype() makes
        # an independent copy while the file is still open.
        with fits.open(os.path.join(sample_dir, frame_file)) as hdul:
            sh_frame = hdul[0].data.astype('float32')
        if normalization:
            normalize(sh_frame)
        sh_frame = sh_frame[..., np.newaxis]
        with fits.open(os.path.join(sample_dir, 'perfect_acts.fits')) as hdul:
            dmcommand = hdul[0].data.astype('float64')
        data.append((sh_frame, dmcommand))
        print(f'Load data successfully {i}')
    random.shuffle(data)
    print("Data generated complete")
    return data
